{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、配置环境变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import dotenv_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ZHIPUAI_API_KEY\n",
      "DASHSCOPE_API_KEY\n"
     ]
    }
   ],
   "source": [
    "env_variables = dotenv_values('.env')\n",
    "for var in env_variables:\n",
    "    print(var)\n",
    "    os.environ[var] = env_variables[var]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sk-feddf4791axxxxxxx21f8cc335e83'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.getenv(\"DASHSCOPE_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 二、配置LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "\n",
    "import os\n",
    "from typing import Dict, List, Optional, Tuple, Union\n",
    "\n",
    "from zhipuai import ZhipuAI\n",
    "import dashscope\n",
    "import random\n",
    "from http import HTTPStatus\n",
    "import dashscope\n",
    "from dashscope import Generation  # 建议dashscope SDK 的版本 >= 1.14.0\n",
    "\n",
    "PROMPT_TEMPLATE = dict(\n",
    "    RAG_PROMPT_TEMPALTE=\"\"\"使用以上下文来回答用户的问题。如果你不知道答案，就说你不知道。总是使用中文回答。\n",
    "        问题: {question}\n",
    "        可参考的上下文：\n",
    "        ···\n",
    "        {context}\n",
    "        ···\n",
    "        如果给定的上下文无法让你做出回答，请回答数据库中没有这个内容，你不知道。\n",
    "        有用的回答:\"\"\",\n",
    "    InternLM_PROMPT_TEMPALTE=\"\"\"先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案，就说你不知道。总是使用中文回答。\n",
    "        问题: {question}\n",
    "        可参考的上下文：\n",
    "        ···\n",
    "        {context}\n",
    "        ···\n",
    "        如果给定的上下文无法让你做出回答，请回答数据库中没有这个内容，你不知道。\n",
    "        有用的回答:\"\"\"\n",
    ")\n",
    "\n",
    "\n",
    "class BaseModel:\n",
    "    def __init__(self, path: str = '') -> None:\n",
    "        self.path = path\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        pass\n",
    "\n",
    "    def load_model(self):\n",
    "        pass\n",
    "\n",
    "class GLM4Chat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"glm-4\") -> None:\n",
    "        super().__init__(path)\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        client = ZhipuAI(api_key=os.getenv(\"ZHIPUAI_API_KEY\"))  # 填写您自己的APIKey\n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"glm-4\",  # 填写需要调用的模型名称\n",
    "            messages=history\n",
    "        )\n",
    "        return response.choices[0].message\n",
    "\n",
    "class QwenChat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"qwen-turbo\") -> None:\n",
    "        super().__init__(path)\n",
    "        dashscope.api_key = os.getenv(\"DASHSCOPE_API_KEY\")\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = Generation.call(model=\"qwen-turbo\",\n",
    "                                messages=history,\n",
    "                                # 设置随机数种子seed，如果没有设置，则随机数种子默认为1234\n",
    "                                seed=random.randint(1, 10000),\n",
    "                                # 将输出设置为\"message\"格式\n",
    "                                result_format='message')\n",
    "        if response.status_code == HTTPStatus.OK:\n",
    "            return response.output.choices[0].message[\"content\"]\n",
    "        else:\n",
    "            print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (\n",
    "                response.request_id, response.status_code,\n",
    "                response.code, response.message\n",
    "            ))\n",
    "\n",
    "\n",
    "class OpenAIChat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"gpt-3.5-turbo-1106\") -> None:\n",
    "        super().__init__(path)\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()   \n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = client.chat.completions.create(\n",
    "            model=self.model,\n",
    "            messages=history,\n",
    "            max_tokens=150,\n",
    "            temperature=0.1\n",
    "        )\n",
    "        return response.choices[0].message.content\n",
    "\n",
    "class InternLMChat(BaseModel):\n",
    "    def __init__(self, path: str = '') -> None:\n",
    "        super().__init__(path)\n",
    "        self.load_model()\n",
    "\n",
    "    def chat(self, prompt: str, history: List = [], content: str='') -> str:\n",
    "        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)\n",
    "        response, history = self.model.chat(self.tokenizer, prompt, history)\n",
    "        return response\n",
    "\n",
    "\n",
    "    def load_model(self):\n",
    "        import torch\n",
    "        from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'你好！有什么我可以帮助你的吗？'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qwenchat = QwenChat()\n",
    "qwenchat.chat(history=[], prompt=\"hello\", content=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 三、设置最基础的向量数据库VectorStore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from typing import Dict, List, Optional, Tuple, Union\n",
    "import json\n",
    "from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "class VectorStore():\n",
    "    def __init__(self, document: List[str] = ['']) -> None:\n",
    "        self.document = document\n",
    "\n",
    "    def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:\n",
    "        \n",
    "        self.vectors = []\n",
    "        for doc in tqdm(self.document, desc=\"Calculating embeddings\"):\n",
    "            self.vectors.append(EmbeddingModel.get_embedding(doc))\n",
    "        return self.vectors\n",
    "\n",
    "    def batch_split_list(self, lst,batch_size):\n",
    "        return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]\n",
    "\n",
    "    def get_vector_batch(self, EmbeddingModel: BaseEmbeddings, batch:int) -> List[List[float]]:\n",
    "\n",
    "        self.vectors = []\n",
    "        self.document = self.batch_split_list(self.document, batch)\n",
    "        for doc in tqdm(self.document, desc=\"Calculating embeddings\"):\n",
    "            self.vectors.extend(EmbeddingModel.get_embeddings(doc))\n",
    "        return self.vectors\n",
    "\n",
    "    def persist(self, path: str = 'storage'):\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        with open(f\"{path}/doecment.json\", 'w', encoding='utf-8') as f:\n",
    "            json.dump(self.document, f, ensure_ascii=False)\n",
    "        if self.vectors:\n",
    "            with open(f\"{path}/vectors.json\", 'w', encoding='utf-8') as f:\n",
    "                json.dump(self.vectors, f)\n",
    "\n",
    "    def load_vector(self, path: str = 'storage'):\n",
    "        with open(f\"{path}/vectors.json\", 'r', encoding='utf-8') as f:\n",
    "            self.vectors = json.load(f)\n",
    "        with open(f\"{path}/doecment.json\", 'r', encoding='utf-8') as f:\n",
    "            self.document = json.load(f)\n",
    "\n",
    "    def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:\n",
    "        return BaseEmbeddings.cosine_similarity(vector1, vector2)\n",
    "\n",
    "    def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:\n",
    "        query_vector = EmbeddingModel.get_embedding(query)\n",
    "\n",
    "        end_time = time.time()\n",
    "\n",
    "        result = np.array([self.get_similarity(query_vector, vector)\n",
    "                          for vector in self.vectors])\n",
    "        print(' 检索 cost %f second' % (time.time() - end_time))\n",
    "        return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = VectorStore()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector.load_vector('../../storage/github_data') # 加载本地的知识库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1024"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(vector.vectors[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、定义embedding算法\n",
    "NOTICE：\n",
    "1.需要注意的是PaddleEmbedding()类中向量维度为768，而ZhipuEmbedding()类中向量维度为1024\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from RAG.PaddleEmbedding import PaddleEmbedding\n",
    "embedding = ZhipuEmbedding()# 创建EmbeddingModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 检索 cost 0.005692 second\n",
      "知识库输出：包含了实际的文件，而暂存区是一个准备好下次提交的文件列表。\n",
      "提交更改\n",
      "提交是Git中的基本操作，它会将暂存区的更改记录到仓库中。每次提交都会在仓库中创建一个快照，并允许之后恢复到该状态。\n",
      "bash\n",
      "gitcommit-m\"描述性的提交信息\"\n",
      "查看提交历史\n",
      "要查看提交历史，可以使用gitlog命令：\n",
      "gitlog\n",
      "还可以使用多种选项来定制显示的日志输出。\n",
      "撤销操作\n",
      "如果需要撤销操作，Git提供了几个命令：\n",
      "\n",
      "撤销工作区的修改：\n",
      "\n",
      "bash\n",
      "gitcheckout--<文件名>\n",
      "\n",
      "撤销暂存区的文件：\n",
      "\n",
      "bash\n",
      "gitresetHEAD<文件名>\n",
      "\n",
      "撤销提交（创建一个新的提交来撤销之前的提交）：\n",
      "\n",
      "bash\n",
      "gitrevert<提交ID>\n",
      "标签管理\n",
      "标签是指向特定提交的引用，通常用于版本发布。创建一个新标签：\n",
      "bash\n",
      "gittag<标签名>\n",
      "列出所有标签：\n",
      "bash\n",
      "gittag\n",
      "删除标签：\n",
      "bash\n",
      "gittag-d<标签名>\n",
      "查看标签信息：\n",
      "bash\n",
      "gitshow<标签名>\n",
      "推送标签到远程仓库：\n",
      "bash\n",
      "gitpushorigin<标签名>\n",
      "分支管理\n",
      "分支的概念\n",
      "在Git中，分支是用来隔离开发工作的。每个分支都是一个独立的开发环境，互不影响。分支可以很方便地被创建和合并，因此许多开发者使用分支来进行特性开发、修复bug或者尝试新想法。\n",
      "Git的一个核心概念是几乎所有操作都是本地执行的，分支也不例外。这意味着你在本地创建或切换分支，不需要与远程仓库进行通信。\n",
      "创建与合并分支\n",
      "在Git中创建新分支可以使用gitbranch命令，合并分支则使用gitmerge命令。\n",
      "```bash\n",
      "创建新分支\n",
      "\n"
     ]
    }
   ],
   "source": [
    "question = 'Git中的文件有哪几种状态?'\n",
    "content = vector.query(question, EmbeddingModel=embedding, k=2)[0]\n",
    "print(f'知识库输出：{content}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "知识库输出：在Git中，文件主要有三种状态：\n",
      "\n",
      "1. 工作区（Working Directory）：这是你的开发环境，包含了你当前正在编辑的文件。当你从仓库检出代码或者直接修改文件时，这些修改都在工作区。\n",
      "\n",
      "2. 暂存区（Staging Area, Index）：也称为索引，是一个缓冲区，当你对工作区的文件做了修改后，可以选择哪些修改添加到暂存区，准备提交。暂存区的文件是未提交但已准备好的更改。\n",
      "\n",
      "3. 仓库（Repository）：Git的中央存储库，包含了所有的提交历史、分支和标签。提交后的文件状态会保存在仓库中，可以通过提交ID访问。\n",
      "\n",
      "当提到Git中的文件状态时，通常指的是这三种状态。工作区和暂存区是本地的状态，仓库则是版本控制的状态。\n"
     ]
    }
   ],
   "source": [
    "chat = QwenChat()\n",
    "print(f\"知识库输出：{chat.chat(question, [], content)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
